function [net, info] = proj_classification()
%PROJ_CLASSIFICATION Interactive driver that trains a classifier with MatConvNet.
%   [NET, INFO] = PROJ_CLASSIFICATION() prompts for the imbalance-handling
%   method, the loss id, the weighting parameter and the dataset name, loads
%   (or builds) the matching imdb, initialises the network, derives per-class
%   cost weights from the label frequencies, optionally re-balances the data,
%   trains the network and prints the best validation G-mean.
%
%   Outputs:
%     net  - network returned by cnn_train_c_relevance_2.
%     info - training statistics returned by cnn_train_c_relevance_2;
%            info.val(:).Gmean is read to report the best epoch.
%
%code by Lamyaa Sadouk, FST Settat
%based on the MNIST and CIFAR examples from MatConvNet

run matconvnet-1.0-beta25/matlab/vl_setupnn ; %run('matconvnet-1.0-beta16', 'matlab', 'vl_setupnn.m') ;

opts.learningRate = 0.01; % CNN-Mnist (0.0001) / MLPs: when SGD, 0.01 for all except yeast and when Adam  0.001 for all but mnist(0.0001) for adam 

opts.continue = false;

% Training runs on GPU device 1 as configured here.
opts.gpus = [1] ; % set to [] for CPU mode, or to [1] to use one GPU device

%opts.solver = @adam; % uncomment if optimizer is Adam. Otherwise, if using SGD then comment the line.

method =input('Please select the method for handling imbalanced data (o)data pre-processing: Oversampling, (u)data pre-processing: Undersampling, (n)nothing  ', 's');

lambda =input('Please enter the loss (0)log, (1)CS log, (2)msHinge, (3)CS msHinge, (4)L2, (5)CS-ST L2, (6)our CS L2,(7)CS_sum L2,(8)sq.hinge,(9)CS sq.hinge,(10)L2 estimate,(11)CS L2 estimate,(12)cub.hinge,(13)CS cub.hinge ');
% Fail early with a clear message: an empty or non-scalar answer would
% otherwise surface later as an obscure "operands to ||" error.
if ~(isnumeric(lambda) && isscalar(lambda))
    error('proj_classification:invalidLoss', ...
          'The loss selector must be a single numeric value (see the prompt for the valid codes).');
end

opts.weight_parameter =input('Please set the weighting parameter - Example: (2)for MLPs and (50)for CNNs ');

kfold = 3; % kfold =input('Please enter the k-fold (k-1 for training & 1 for testing)_(0 for testing):  '); % k-fold = 9 
dataset =input('Please select the dataset:(ionosphere)/("pid" - Pima Indians Diabetes)/(WP_Breast_Cancer)/(SPECTF_Heart)/(yeast_8l)/(car)/(satimage)/(thyroid)/(mnist10)/(mnist30)/(mnist40)/(mnist50) ', 's');

% opts.expDir is where trained networks and plots are saved.
opts.expDir = fullfile('result_nets',strcat('data_', dataset,'_r', ...
     int2str(lambda),'_','multiclass', '_', method,'SGD' )) ; % 'newLossAdam'

% --------------------------------------------------------------------
%                                                         Prepare data
% --------------------------------------------------------------------
imdb_filename = fullfile('data_preprocessed', strcat('imdb_', dataset,'_c','.mat')); 
if exist(imdb_filename, 'file')
    % Load into a struct instead of the function workspace so that
    % variables stored in the MAT-file cannot silently overwrite opts,
    % lambda, method, ... defined above.
    loaded = load(imdb_filename) ;
    if isfield(loaded, 'imdb')
        imdb = loaded.imdb ;
    elseif isfield(loaded, 'images')
        % File was written with save(imdb_filename, '-struct', 'imdb').
        imdb = loaded ;
    else
        error('proj_classification:badImdbFile', ...
              'File "%s" contains neither an ''imdb'' variable nor an ''images'' field.', ...
              imdb_filename);
    end
else
    imdb = setup_data(kfold, dataset);
end

% --------------------------------------------------------------------
%                                                         Prepare model
%                                                         architecture
% --------------------------------------------------------------------
% specify the network architecture w/ cnn_init function
opts.numEpochs = 30; %65-thyroid %55 satimage %70-car %100-yeast % 110-glass %30for breastCancer %30 for heart and pid %20 for ionosphere %numel(opts.learningRate); %50 TO BE CHOSEN BASED ON THE DATASET
net = cnn_init_classification(imdb, lambda, dataset);  


%% -------------------------------------------------------------------
%                                                 Prepare model
%                                                 distribution
% --------------------------------------------------------------------
% The class weights are derived from ALL labels in imdb (every set) and
% BEFORE any over-/under-sampling performed below.
label =imdb.images.labels;

% NOTE(review): the prompt also lists 8, 10 and 12 as non-cost-sensitive
% losses; they currently take the weighted branch. Confirm against
% cnn_train_c_relevance_2 before changing this condition.
if lambda ==0 || lambda ==2 || lambda ==4
    pd_model = [];
else % cost-sensitive losses: per-class cost derived from class frequency
    % Exact per-class counts. hist(label, unique(label)) gives the same
    % numbers for two or more classes, but misreads a single class value
    % as a bin COUNT, so the counts are computed explicitly here.
    [~, ~, classIdx] = unique(label(:));
    classCounts = accumarray(classIdx, 1);
    % Weight is 1 - (relative frequency w.r.t. the largest class): the
    % most frequent class gets 0, rarer classes get values closer to 1.
    pd_model = 1 - classCounts ./ max(classCounts);
end
% per-class weights for the chosen loss (empty when unweighted)
opts.pd_model = pd_model;
% maximum value of pdf for the chosen model
% opts.pd_model_max = pd_model_max; % for code cnn_train_relevance


% --------------------------------------------------------------------
%                                                      Balance data if
%                                                      selected by user
% --------------------------------------------------------------------
if(isequal(method,'o') || isequal(method,'u'))
    imdb = balance_data(method, imdb, dataset);
end

%% -------------------------------------------------------------------
%                                                                Train
% --------------------------------------------------------------------
%1. setup the batch size
%opts.batchSize is the number of training images in each batch.
% The batch size grows with the number of samples in the (possibly
% re-balanced) imdb.
numSamples = length(imdb.images.labels);
if numSamples > 100000
    opts.batchSize = 300 ;
elseif numSamples > 50000
    opts.batchSize = 150;
elseif numSamples > 20000
    opts.batchSize = 100;
elseif numSamples > 12000
    opts.batchSize = 50;
elseif numSamples > 4800 %for mnist1,mnist10,mnist30,mnist50
    opts.batchSize = 30;      % just added, b4 was 10 for all <12000
elseif numSamples > 1000
    opts.batchSize = 10;
else
    opts.batchSize = 5;
end

%opts.errorFunction = 'binary'; % 'multiclass' 'euclideanloss'
% Samples with imdb.images.set == 2 are passed as the validation subset.
[net, info] = cnn_train_c_relevance_2(net, imdb, @getBatch, opts, ...
    'val', find(imdb.images.set == 2)) ; % cnn_train_r_relevance % cnn_train_r_max % cnn_train_r

% Report the best validation G-mean and the epoch index at which it occurred.
[max_gmean, ind_max_gmean] = max([info.val(:).Gmean]);
fprintf('Highest G-mean is %f (%d)\n',max_gmean,ind_max_gmean )


end

% --------------------------------------------------------------------
function [im, labels] = getBatch(imdb, batch)
%GETBATCH Fetch the samples and labels of one mini-batch from the imdb.
%   [IM, LABELS] = GETBATCH(IMDB, BATCH) is the batch callback handed to
%   the training routine.
%
%   Inputs:
%     imdb  - database struct; imdb.images.data is indexed along its 4th
%             dimension and imdb.images.labels along its 2nd dimension.
%     batch - indices of the samples selected for this batch.
%
%   Outputs:
%     im     - imdb.images.data restricted to BATCH along dimension 4.
%     labels - first row of imdb.images.labels restricted to BATCH.

images = imdb.images ;

% Ground-truth labels for the requested samples (row 1 only).
labels = images.labels(1, batch) ;

% Matching slices of the data array, selected along the 4th dimension.
im = images.data(:, :, :, batch) ;

end


